from tqdm import tqdm


def train(net, train_loader, optimizer, criterion, writer, args, epoch, index_num):
    """Run one training epoch over ``train_loader`` and log the per-batch loss.

    Args:
        net: Model to train; put into training mode via ``net.train()`` and
            called as ``net(inputs)``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches; both
            elements must support ``.to(device)``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        criterion: Loss function called as ``criterion(outputs, labels)``;
            must return a scalar tensor (``.backward()`` and ``.item()`` are
            called on it).
        writer: Logger exposing ``add_scalar(tag, value, step)`` — presumably
            a TensorBoard ``SummaryWriter``; confirm at the call site.
        args: Namespace providing ``args.device``, the device each batch is
            moved to.
        epoch: Epoch number, used only in the progress-bar description.
        index_num: Global step at which the first batch of this epoch is
            logged.

    Returns:
        The global step following the last logged batch (``index_num`` plus
        the number of batches processed). Feed it back in as ``index_num`` for
        the next epoch so the "loss/train" curve keeps advancing instead of
        restarting at the same steps.
    """
    net.train()
    # Context manager guarantees the bar is closed even if a batch raises.
    with tqdm(train_loader, desc="Epoch " + str(epoch)) as train_tqdm:
        for inputs, labels in train_tqdm:
            optimizer.zero_grad()
            outputs = net(inputs.to(args.device))
            loss = criterion(outputs, labels.to(args.device))
            loss.backward()
            optimizer.step()
            # Extract the scalar once: logging a plain float avoids handing the
            # graph-attached tensor to the writer and a second device sync.
            loss_value = loss.item()
            writer.add_scalar("loss/train", loss_value, index_num)
            index_num += 1
            train_tqdm.set_postfix({"loss": "%.3g" % loss_value})
    # Previously the advanced counter was discarded; return it to the caller.
    return index_num
